#!/usr/bin/env python3
from __future__ import annotations

from collections import defaultdict
from copy import deepcopy
from pathlib import Path
from typing import Callable, List, Optional

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import wandb
from matplotlib.ticker import AutoMinorLocator, MultipleLocator
import pandas as pd

# Earlier pastel palette, kept for reference. It used to be bound to `colors`
# and was then immediately overwritten by the palette below, so it was never used.
pastel_colors = [
    '#23aaff',
    '#66c56c',
    '#ff7575',
    '#a6ddff',
    '#97c29a',
    '#ffc9c9',
]

# Default palette. create_plot cycles through it (in sorted column order)
# when no explicit group2color mapping is given.
colors = [
    '#f44336',  # red
    '#8fce00',  # green
    '#f1c232',  # yellow
    '#2986cc',  # blue
    '#070000',  # black
    '#a64d79',  # purple
    '#999999',  # grey
    '#38761d',  # darkgreen
]

# Number of seeds assumed per experiment group (used when scaling std to a
# standard error, and by a currently disabled per-group sanity check).
num_seeds = 5

# 'Dark2' qualitative colormap, kept available as an alternative palette source
# (e.g. `[cmap(i) for i in range(20)]`); the default `colors` list does not use it.
cmap = mpl.colormaps['Dark2']


def skip_column(column: str) -> bool:
    """Return True for DataFrame columns that are not themselves a group to plot.

    'step' is the shared x-axis column, and columns ending in '__MAX' / '__MIN'
    hold the upper / lower edge of a group's shaded band. Every other column is
    a group's mean curve.
    """
    # Always return a real bool; the original fell through to an implicit None.
    return column == 'step' or column.endswith(('__MAX', '__MIN'))


def create_plot(data_frame, title: str, xlabel: str, ylabel: str, fname: Path | str, group2legend: Optional[dict], group2color: Optional[dict] = None, hbars: Optional[List[float]] = None, dry_run: bool = False):
    """Plot one mean line plus a shaded band per group column and save the figure.

    Args:
        data_frame: DataFrame with a 'step' column (x values) and, for every group
            column ``g``, companion ``g + '__MIN'`` / ``g + '__MAX'`` columns giving
            the lower / upper edge of that group's shaded band.
        title: figure title; a falsy value means no title.
        xlabel: x-axis label.
        ylabel: y-axis label.
        fname: output path handed to ``Figure.savefig``.
        group2legend: group column -> legend label, or None to draw no legend.
        group2color: group column -> color. When None, colors cycle through the
            module-level ``colors`` list in sorted column order.
        hbars: y values of dashed horizontal reference lines; only values > 0 are drawn.
        dry_run: if True, only report the target path instead of writing the file.
    """
    print('create_plot -- keys', data_frame.keys())

    fig, ax = plt.subplots()
    try:
        group_columns = [col for col in sorted(data_frame.keys()) if not skip_column(col)]
        for idx, column in enumerate(group_columns):
            xs = data_frame['step']
            color = colors[idx % len(colors)] if group2color is None else group2color[column]
            # label=None gives the line matplotlib's default (hidden) label.
            label = None if group2legend is None else group2legend[column]
            ax.plot(xs, data_frame[column], '-', color=color, label=label)
            ax.fill_between(xs, data_frame[column + '__MIN'], data_frame[column + '__MAX'],
                            color=color, alpha=0.17)

        for hbar in hbars or []:
            if hbar > 0:
                ax.axhline(y=hbar, color='#a88132', linestyle='--',
                           label='best oracle' if group2legend is not None else None)

        if title:
            ax.set_title(title)
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)

        # Without labels there is nothing to put in a legend, and calling
        # ax.legend() anyway only produces a "No artists with labels" warning.
        if group2legend is not None:
            ax.legend(loc='upper left', fontsize=18)

        if not dry_run:
            fig.savefig(fname)
            print(f'>>> The plot is saved to {fname}')
        else:
            print(f'>>> The plot would be saved to {fname} (dry-run is True!)')
    finally:
        # Release the figure; pyplot keeps every figure alive until it is closed,
        # so generating many plots in one process would otherwise leak memory.
        plt.close(fig)


def fetch_expert_vals(run) -> dict[str, float]:
    """Fetch the logged evaluation return of each pre-trained expert of ``run``.

    For every step in ``run.config['load_expert_step']`` (visited in descending
    order), query the run history for ``prep/expert-{step}-eval/returns_mean`` and
    average the values that were logged.

    Args:
        run: object exposing ``config`` (a mapping) and ``history(keys=[...])``
            returning a pandas DataFrame (presumably a wandb public-API run;
            TODO confirm).

    Returns:
        Mapping from metric key to the mean of its logged (non-NaN) values. Keys
        that were never logged are omitted.

    Raises:
        KeyError: if ``run.config`` has no ``'load_expert_step'`` entry.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if 'load_expert_step' not in run.config:
        raise KeyError(f"run {getattr(run, 'id', '?')} has no 'load_expert_step' in its config")

    prep_experts_vals = defaultdict(list)
    prep_experts = [f'prep/expert-{exp_step}-eval/returns_mean'
                    for exp_step in sorted(run.config['load_expert_step'], reverse=True)]
    print('prep_experts', prep_experts)

    for prep_expert in prep_experts:
        entry = run.history(keys=[prep_expert])
        if prep_expert not in entry:
            continue
        # `entry[prep_expert]` is a Series. float(Series) raises for more than one
        # row and is deprecated for a single row, so collect every value instead.
        values = entry[prep_expert].dropna().astype(float).tolist()
        if values:
            prep_experts_vals[prep_expert].extend(values)

    # Aggregate the collected values into one mean per key (keys are not added or
    # removed, so reassigning during iteration is safe).
    for key, vals in prep_experts_vals.items():
        prep_experts_vals[key] = np.mean(vals)

    return prep_experts_vals


def _resolve_plot_path(config: dict, plot_name: str, ext: str, use_stderr: bool,
                       show_title: bool, show_legend: bool) -> Path:
    """Create the output directory for a plot variant and return the target file path.

    The directory is ``config['plot_dir']``, optionally followed by a subdirectory
    named after the disabled features ('notitle', 'nolegend', joined by '-'), and
    optionally by a 'stderr' subdirectory.
    """
    properties = []
    if not show_title:
        properties.append('notitle')
    if not show_legend:
        properties.append('nolegend')
    # '-'.join([]) == '' and Path / '' is a no-op, so no properties -> no extra level.
    plot_dir = Path(config['plot_dir']) / '-'.join(properties)
    if use_stderr:
        plot_dir = plot_dir / 'stderr'
    plot_dir.mkdir(parents=True, exist_ok=True)
    return plot_dir / (plot_name + ext)


def _collect_run_data(runs: list, groups: list, xkey: str, ykey: str) -> list:
    """Fetch x/y history for each run, plus expert names for non-'pg-gae' runs."""
    run_data = []
    for run_idx, (run, group) in enumerate(zip(runs, groups)):
        print(f'run-idx: {run_idx}\trun-id: {run.id}')
        run_d = {'run-id': run.id, 'group': group}

        if run.config['algorithm'] != 'pg-gae':
            # experts_info entries look like [policy, '.../<step>_checkpoint'].
            expert_names = []
            for policy, path in run.config['experts_info']:
                step = int(Path(path).stem.split('_')[0])
                expert_names.append(f'{policy}-{step / 1000:.0f}k')
            print('expert names', expert_names)
            run_d['expert_names'] = expert_names

        history = run.history(keys=[xkey, ykey])
        if history.empty:
            run_d['xvals'], run_d['yvals'] = [], []
        else:
            run_d['xvals'] = history[xkey].tolist()
            run_d['yvals'] = history[ykey].tolist()
        print(run_d)
        run_data.append(run_d)
    return run_data


def _aggregate_groups(run_data: list, use_stderr: bool) -> dict:
    """Reduce per-run y values to per-group mean and mean±spread columns.

    Returns a column dict with 'step' plus ``g``, ``g + '__MIN'``, ``g + '__MAX'``
    for every group ``g``; every column is truncated to the shortest length.
    """
    # NOTE: the x axis is taken from the last run; runs are assumed to share
    # logging steps (columns are truncated to a common length below).
    column2arr = {'step': run_data[-1]['xvals']}
    for group in dict.fromkeys(run_d['group'] for run_d in run_data):
        group_yvals = [run_d['yvals'] for run_d in run_data if run_d['group'] == group]
        lengths = [len(seq) for seq in group_yvals]
        min_length, max_length = min(lengths), max(lengths)
        if min_length != max_length:
            print(f'WARN: some yvals is shorter than others: {min_length} vs {max_length}')

        yvals = np.asarray([seq[:min_length] for seq in group_yvals])
        means, stds = yvals.mean(0), yvals.std(0)
        if use_stderr:
            # Standard error uses the number of runs actually in this group.
            stds = stds / np.sqrt(len(group_yvals))

        column2arr[group] = means
        column2arr[group + '__MIN'] = means - stds
        column2arr[group + '__MAX'] = means + stds

    min_length = min(len(seq) for seq in column2arr.values())
    max_length = max(len(seq) for seq in column2arr.values())
    if min_length != max_length:
        print(f'WARN: some columns are shorter than others: {min_length} vs {max_length}')
        column2arr = {column: arr[:min_length] for column, arr in column2arr.items()}
    return column2arr


def main(runs, plot_name, config, ext='.pdf', force: bool = False, dry_run: bool = False, use_stderr: bool = False, best_expert=100):
    """Aggregate a metric over grouped runs and save a mean ± spread plot.

    Args:
        runs: iterable of runs exposing ``id``, ``config`` and ``history(keys=...)``
            (presumably wandb public-API runs; TODO confirm). Materialised once.
        plot_name: file stem of the plot and first line of its title.
        config: plot configuration. Required keys: 'xkey', 'ykey', 'group_keys',
            'plot_dir', 'xlabel', 'ylabel', and 'group2legend' when the legend is
            shown. Optional keys: 'show_title' (default False), 'show_legend'
            (default True), 'extra_txt', 'group2color'.
        ext: file extension including the leading dot.
        force: regenerate even if the target file already exists.
        dry_run: build the plot but do not write it.
        use_stderr: shade mean ± standard error instead of mean ± std; output goes
            to a 'stderr' subdirectory.
        best_expert: y value of the dashed 'best oracle' reference line; None
            disables it.

    Raises:
        RuntimeError: if ``runs`` is empty and the target file does not exist yet.
    """
    # Materialise once: runs is iterated several times below, and a one-shot
    # iterator would otherwise be silently exhausted after the first pass.
    runs = list(runs)
    xkey, ykey = config['xkey'], config['ykey']
    group_keys = config['group_keys']
    # Single source of truth for both the output directory name and the drawn title.
    # The two used to disagree: the directory defaulted show_title to False while the
    # title defaulted it to True, so titled plots were written into 'notitle/'.
    show_title = config.get('show_title', False)
    show_legend = config.get('show_legend', True)

    fname = _resolve_plot_path(config, plot_name, ext, use_stderr, show_title, show_legend)
    if fname.is_file():
        # Skip generating the file unless force=True.
        print(f'The file to generate: {fname} already exists!')
        if not force:
            return
        print('But `force` is True! Regenerating the plot...')

    groups = ['-'.join(str(run.config[key]) for key in group_keys) for run in runs]
    print('groups', groups)
    print('list of runs', runs)
    print('num of runs', len(runs))
    if not runs:
        raise RuntimeError('Hmm, I guess the runs is being empty...', runs)

    run_data = _collect_run_data(runs, groups, xkey, ykey)
    df = pd.DataFrame.from_dict(_aggregate_groups(run_data, use_stderr))

    if show_title:
        # Expert names are assumed identical within a plot; use those of the last
        # non-'pg-gae' run, or none at all if every run is 'pg-gae'.
        expert_names = next(
            (run_d['expert_names'] for run_d in reversed(run_data) if 'expert_names' in run_d),
            [],
        )
        title = '\n'.join((plot_name, f"exp-{', '.join(expert_names)}"))
        if 'extra_txt' in config:
            title += '\n' + config['extra_txt']
    else:
        title = None

    group2legend = config['group2legend'] if show_legend else None

    create_plot(df,
                title=title,
                xlabel=config['xlabel'],
                ylabel=config['ylabel'],
                fname=fname,
                hbars=[best_expert] if best_expert is not None else None,
                group2color=config.get('group2color', None),
                group2legend=group2legend,
                dry_run=dry_run,
                )
